/*=========================================================================
 *
 *  Copyright NumFOCUS
 *
 *  Licensed under the Apache License, Version 2.0 (the "License");
 *  you may not use this file except in compliance with the License.
 *  You may obtain a copy of the License at
 *
 *         https://www.apache.org/licenses/LICENSE-2.0.txt
 *
 *  Unless required by applicable law or agreed to in writing, software
 *  distributed under the License is distributed on an "AS IS" BASIS,
 *  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 *  See the License for the specific language governing permissions and
 *  limitations under the License.
 *
 *=========================================================================*/
#ifndef itkKdTreeGenerator_hxx
#define itkKdTreeGenerator_hxx

namespace itk::Statistics
{
template <typename TSample>
KdTreeGenerator<TSample>::KdTreeGenerator()
  : m_SourceSample(nullptr)
  , m_Subsample(SubsampleType::New())
  , m_BucketSize(16)
{}

// Print the state of this generator to `os`, one item per line, each line
// prefixed with `indent`. The superclass prints its own state first.
template <typename TSample>
void
KdTreeGenerator<TSample>::PrintSelf(std::ostream & os, Indent indent) const
{
  Superclass::PrintSelf(os, indent);

  // The source sample is optional until SetSample() has been called.
  os << indent << "Source Sample: ";
  if (m_SourceSample == nullptr)
  {
    os << "not set." << std::endl;
  }
  else
  {
    os << m_SourceSample << std::endl;
  }

  os << indent << "Bucket Size: " << m_BucketSize << std::endl;
  os << indent << "MeasurementVectorSize: " << m_MeasurementVectorSize << std::endl;
}

// Set the source sample from which the k-d tree will be generated.
//
// The internal subsample is pointed at `sample` and initialized with all of
// its instances. The measurement vector size is read from `sample`, and the
// temporary bound/mean vectors used during tree generation are resized to it.
//
// Throws an ITK exception if `sample` is null; in that case no member of this
// object is modified.
template <typename TSample>
void
KdTreeGenerator<TSample>::SetSample(TSample * sample)
{
  // `sample` is dereferenced below, so a null pointer must be rejected up
  // front instead of being allowed to cause undefined behavior.
  if (sample == nullptr)
  {
    itkExceptionStringMacro("Cannot set a null sample");
  }

  m_SourceSample = sample;
  m_Subsample->SetSample(sample);
  m_Subsample->InitializeWithAllInstances();
  m_MeasurementVectorSize = sample->GetMeasurementVectorSize();

  // Size the scratch vectors once here so the recursive tree generation can
  // reuse them.
  NumericTraits<MeasurementVectorType>::SetLength(m_TempLowerBound, m_MeasurementVectorSize);
  NumericTraits<MeasurementVectorType>::SetLength(m_TempUpperBound, m_MeasurementVectorSize);
  NumericTraits<MeasurementVectorType>::SetLength(m_TempMean, m_MeasurementVectorSize);
}

// Set the bucket size: a range of instances no larger than this is stored in
// a single terminal node instead of being split further (see GenerateTreeLoop).
template <typename TSample>
void
KdTreeGenerator<TSample>::SetBucketSize(unsigned int size)
{
  this->m_BucketSize = size;
}

// Build the k-d tree for the current source sample.
//
// Does nothing if no source sample has been set. The output tree object is
// created (and given the sample and bucket size) only if it does not exist
// yet. Throws an ITK exception if the measurement vector size of this
// generator and that of its subsample disagree.
template <typename TSample>
void
KdTreeGenerator<TSample>::GenerateData()
{
  // Nothing to build until a source sample has been provided.
  if (m_SourceSample == nullptr)
  {
    return;
  }

  // Lazily create the output tree.
  if (m_Tree.IsNull())
  {
    m_Tree = KdTreeType::New();
    m_Tree->SetSample(m_SourceSample);
    m_Tree->SetBucketSize(m_BucketSize);
  }

  // Sanity check: the subsample must have measurement vectors of the same
  // length as the one recorded by this generator.
  const SubsamplePointer subsample = this->GetSubsample();
  if (this->GetMeasurementVectorSize() != subsample->GetMeasurementVectorSize())
  {
    itkExceptionStringMacro("Measurement Vector Length mismatch");
  }

  // The root covers the whole value range in every dimension.
  MeasurementVectorType lowerBound;
  MeasurementVectorType upperBound;
  NumericTraits<MeasurementVectorType>::SetLength(lowerBound, m_MeasurementVectorSize);
  NumericTraits<MeasurementVectorType>::SetLength(upperBound, m_MeasurementVectorSize);

  const MeasurementType lowestValue = NumericTraits<MeasurementType>::NonpositiveMin();
  const MeasurementType highestValue = NumericTraits<MeasurementType>::max();
  for (unsigned int dim = 0; dim < m_MeasurementVectorSize; ++dim)
  {
    lowerBound[dim] = lowestValue;
    upperBound[dim] = highestValue;
  }

  // Recursively generate the nodes over all instances and install the root.
  m_Tree->SetRoot(this->GenerateTreeLoop(0, m_Subsample->Size(), lowerBound, upperBound, 0));
}

// Create a nonterminal node for the instances in [beginIndex, endIndex).
//
// The range is partitioned around its median element along the dimension with
// the widest spread. The left child is generated from [beginIndex, median),
// the right child from [median + 1, endIndex), and the median instance itself
// is stored in the returned node.
//
// `lowerBound` / `upperBound` are modified in the cutting dimension while each
// child is generated and are restored to their incoming values before return.
// `level` is only forwarded (incremented) to the recursive calls.
template <typename TSample>
inline auto
KdTreeGenerator<TSample>::GenerateNonterminalNode(unsigned int            beginIndex,
                                                  unsigned int            endIndex,
                                                  MeasurementVectorType & lowerBound,
                                                  MeasurementVectorType & upperBound,
                                                  unsigned int            level) -> KdTreeNodeType *
{
  using ChildNodeType = typename KdTreeType::KdTreeNodeType;
  using InnerNodeType = KdTreeNonterminalNode<TSample>;

  const SubsamplePointer subsample = this->GetSubsample();

  // Find the most widely spread dimension. Ties go to the highest dimension
  // index because the comparison is `>=`.
  Algorithm::FindSampleBoundAndMean<SubsampleType>(
    subsample, beginIndex, endIndex, m_TempLowerBound, m_TempUpperBound, m_TempMean);

  unsigned int    cutDimension = 0;
  MeasurementType widestSpread = NumericTraits<MeasurementType>::NonpositiveMin();
  for (unsigned int dim = 0; dim < m_MeasurementVectorSize; ++dim)
  {
    const MeasurementType spread = m_TempUpperBound[dim] - m_TempLowerBound[dim];
    if (spread >= widestSpread)
    {
      widestSpread = spread;
      cutDimension = dim;
    }
  }

  // Find the median element with NthElement, which is based on the STL
  // implementation of the QuickSelect algorithm. It is given the median as an
  // offset relative to beginIndex.
  const unsigned int    medianOffset = (endIndex - beginIndex) / 2;
  const MeasurementType cutValue =
    Algorithm::NthElement<SubsampleType>(m_Subsample, cutDimension, beginIndex, endIndex, medianOffset);
  const unsigned int medianIndex = medianOffset + beginIndex;

  // Remember the incoming bounds of the cutting dimension so they can be
  // restored after each child has been generated.
  const MeasurementType savedLowerBound = lowerBound[cutDimension];
  const MeasurementType savedUpperBound = upperBound[cutDimension];

  // Left child: instances before the median, upper bound capped at the cut.
  upperBound[cutDimension] = cutValue;
  ChildNodeType * leftChild = this->GenerateTreeLoop(beginIndex, medianIndex, lowerBound, upperBound, level + 1);
  upperBound[cutDimension] = savedUpperBound;

  // Right child: instances after the median, lower bound raised to the cut.
  lowerBound[cutDimension] = cutValue;
  ChildNodeType * rightChild = this->GenerateTreeLoop(medianIndex + 1, endIndex, lowerBound, upperBound, level + 1);
  lowerBound[cutDimension] = savedLowerBound;

  // The median instance belongs to this node, not to either child.
  auto * innerNode = new InnerNodeType(cutDimension, cutValue, leftChild, rightChild);
  innerNode->AddInstanceIdentifier(subsample->GetInstanceIdentifier(medianIndex));

  return innerNode;
}

// Generate the (sub)tree for the instances in [beginIndex, endIndex).
//
// - More instances than the bucket size: delegate to GenerateNonterminalNode.
// - Empty range: return the tree's shared empty terminal node.
// - Otherwise: return a new terminal node holding every instance identifier
//   in the range.
template <typename TSample>
inline auto
KdTreeGenerator<TSample>::GenerateTreeLoop(unsigned int            beginIndex,
                                           unsigned int            endIndex,
                                           MeasurementVectorType & lowerBound,
                                           MeasurementVectorType & upperBound,
                                           unsigned int            level) -> KdTreeNodeType *
{
  // Too many instances for a single bucket: split the range further.
  if (endIndex - beginIndex > m_BucketSize)
  {
    return this->GenerateNonterminalNode(beginIndex, endIndex, lowerBound, upperBound, level + 1);
  }

  // No instances at all: hand back the pointer to the empty terminal node.
  if (beginIndex == endIndex)
  {
    return m_Tree->GetEmptyTerminalNode();
  }

  // Small, non-empty range: collect its instances into a terminal node.
  auto *                 terminalNode = new KdTreeTerminalNode<TSample>();
  const SubsamplePointer subsample = this->GetSubsample();
  for (unsigned int index = beginIndex; index < endIndex; ++index)
  {
    terminalNode->AddInstanceIdentifier(subsample->GetInstanceIdentifier(index));
  }

  return terminalNode;
}
} // namespace itk::Statistics

#endif
